import torch
import time
import pandas as pd
from datetime import datetime
from contextlib import contextmanager


class GPUPowerMonitor:
    """Sample one GPU's power draw on a background thread and summarize it.

    Usage::

        monitor = GPUPowerMonitor(gpu_id=0, sampling_rate=0.1)
        with monitor.measure():
            run_workload()
        print(monitor.get_statistics())
    """

    def __init__(self, gpu_id=0, sampling_rate=0.1):
        """
        Initialize GPU power monitor

        Args:
            gpu_id (int): CUDA device index of the GPU to monitor
            sampling_rate (float): Seconds to sleep between consecutive samples
                (a sampling *interval*, despite the name)

        Raises:
            AssertionError: If CUDA is unavailable or ``gpu_id`` is not a valid
                device index.
        """
        self.gpu_id = gpu_id
        self.sampling_rate = sampling_rate
        self.measurements = []
        self.running = False
        # Exception raised inside the sampling thread, if any; measure()
        # re-raises it in the caller's thread.
        self._error = None

        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``; AssertionError is kept so callers see the same type.
        if not torch.cuda.is_available():
            raise AssertionError("CUDA GPU not available")
        if not 0 <= gpu_id < torch.cuda.device_count():
            raise AssertionError(f"GPU {gpu_id} not found")

    def _record_power(self):
        """Append power samples to ``self.measurements`` while ``self.running`` is True.

        Each sample is a dict with a wall-clock ``'timestamp'`` (datetime) and
        ``'power_watts'``. Any exception from sampling is stored in
        ``self._error`` rather than dying silently with the thread.
        """
        try:
            while self.running:
                # torch.cuda.power_draw reports milliwatts; store watts.
                power = torch.cuda.power_draw(self.gpu_id) / 1000.0

                self.measurements.append({
                    'timestamp': datetime.now(),
                    'power_watts': power
                })
                time.sleep(self.sampling_rate)
        except Exception as exc:  # forwarded, not swallowed: see measure()
            self._error = exc

    @contextmanager
    def measure(self):
        """Context manager that samples GPU power while its body runs.

        Clears previous measurements, starts a sampling thread on entry, and
        stops and joins it on exit, including when the body raises.

        Raises:
            RuntimeError: If the sampling thread failed and the body completed
                normally; the sampler's exception is chained as ``__cause__``.
        """
        import threading

        self.measurements = []
        self._error = None
        # Created and started before the ``try`` so ``finally`` never sees an
        # unbound or unstarted thread.
        thread = threading.Thread(target=self._record_power, daemon=True)
        self.running = True
        thread.start()

        try:
            # Yield control back to the code being measured
            yield
        finally:
            self.running = False
            thread.join()

        # Only reached when the body exited without an exception, so a sampler
        # failure never masks an error raised by the measured code.
        if self._error is not None:
            raise RuntimeError(
                f"GPU power sampling failed for GPU {self.gpu_id}"
            ) from self._error

    def get_statistics(self):
        """Summarize the most recently collected power measurements.

        Returns:
            dict: ``{}`` if no samples were collected. Otherwise it contains
            ``duration_seconds`` (first to last sample timestamp),
            ``mean_power_watts``, ``max_power_watts``, ``min_power_watts``,
            ``total_energy_joules`` and ``samples_collected``.
        """
        if not self.measurements:
            return {}

        df = pd.DataFrame(self.measurements)
        power = df['power_watts']

        # Weight each sample by the measured time until the next sample rather
        # than assuming exactly ``sampling_rate``. The real loop period is the
        # sleep plus sampling overhead. The last sample has no successor and
        # keeps the nominal interval, so perfectly spaced samples give
        # mean * n * sampling_rate.
        # NOTE(review): timestamps are wall-clock (datetime.now), so a clock
        # adjustment mid-run would distort these intervals.
        intervals = (
            df['timestamp'].diff().shift(-1).dt.total_seconds()
            .fillna(self.sampling_rate)
        )

        stats = {
            'duration_seconds': (df['timestamp'].max() - df['timestamp'].min()).total_seconds(),
            'mean_power_watts': power.mean(),
            'max_power_watts': power.max(),
            'min_power_watts': power.min(),
            'total_energy_joules': float((power * intervals).sum()),
            'samples_collected': len(df)
        }

        return stats

    def reset(self):
        """Clear all measurements"""
        self.measurements = []
